{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8f5b1622",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/constantinpape/torch-em/blob/main/experiments/3D-UNet-Training.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b5781c6",
   "metadata": {},
   "source": [
    "# 3D UNet Training\n",
    "\n",
    "This notebook implements training of a 3D UNet with `torch_em`. It implements training for affinity, boundary and foreground prediction. The training code is also customizable to enable training with a different target.\n",
    "\n",
    "Each section of this notebook starts with a cell that contains the configurable parameters. These cells are marked by the comment `# CONFIGURE ME`. The cells that follow contain code that does not need to be changed, but if you know what you're doing you can customize the training further there.\n",
    "\n",
    "For setting up a local environment that can run `torch_em`, follow [the installation instructions](https://github.com/constantinpape/torch-em#installation). You can also run the notebook in Google Colab; to do so, follow the instructions in the first section."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c0304fe",
   "metadata": {},
   "source": [
    "## Google Colab\n",
    "\n",
    "Run the following cells if you are working in google colab. Skip them otherwise."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d43a6c87",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First, make sure that you are using a GPU. For this, go to:\n",
    "# Runtime->Change runtime type and select Hardware accelerator->GPU\n",
    "# When you then run this cell you should see a gpu status overview\n",
    "# (if something went wrong you will see an error message)\n",
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a99e08cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install conda in your google colab session\n",
    "!pip install -q condacolab\n",
    "import condacolab\n",
    "condacolab.install()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99fb1f76",
   "metadata": {},
   "outputs": [],
   "source": [
    "# update conda\n",
    "!conda update conda"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41e32e2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# use conda to install torch-em dependencies (we don't install torch-em itself, in order to avoid reinstallation of pytorch)\n",
    "!conda install -c pytorch -c conda-forge python-elf dask bioimageio.core"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0facc146",
   "metadata": {},
   "outputs": [],
   "source": [
    "# use pip to install dependencies that need pytorch, including torch-em\n",
    "!pip install --no-deps kornia\n",
    "!pip install --no-deps git+https://github.com/constantinpape/torch-em"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c094870c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# mount your google drive to permanently save data\n",
    "from google.colab import drive\n",
    "drive.mount(\"/content/gdrive\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ee44253",
   "metadata": {},
   "source": [
    "## Imports\n",
    "\n",
    "Import the required libraries, no need to change anything here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c72b8bcc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load tensorboard extension\n",
    "# we will need this later in the notebook to monitor the training progress\n",
    "%load_ext tensorboard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92be7f1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import torch_em\n",
    "import torch_em.data.datasets as torchem_data\n",
    "from torch_em.model import AnisotropicUNet\n",
    "from torch_em.util.debug import check_loader, check_trainer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b14817f",
   "metadata": {},
   "source": [
    "## Training Data\n",
    "\n",
    "Choose the dataset for training. You have two options here: choose a preconfigured dataset or specify the filepaths to your training data.\n",
    "\n",
    "For preconfigured datasets `torch_em` implements download and default pre-processing already. The following pre-configured datasets are currently available:\n",
    "- `cremi`: Neural tissue imaged in electron microscopy.\n",
    "- `mitoem`: Mitochondria imaged in electron microscopy.\n",
    "- `platynereis-cells`: Platynereis cells imaged in electron microscopy.\n",
    "- `platynereis-nuclei`: Platynereis nuclei imaged in electron microscopy.\n",
    "- `snemi`: Neural tissue imaged in electron microscopy.\n",
    "\n",
    "If you're unsure if one of these datasets is suitable for you, just select it and continue to `Check training data` with the default settings in the next sections. You will see example images from the data there.\n",
    "\n",
    "You can also load the training data from local files. `torch_em` supports loading 3d data from hdf5, zarr or n5. For this:\n",
    "- set `train_data_paths` and `val_data_paths` to the file paths of the volumetric image data, and `train_label_paths` and `val_label_paths` to the file paths of the volumetric label data; each of them can also be a list of file paths if you have multiple stacks\n",
    "- set `data_key` and `label_key` to the corresponding internal paths in the files.\n",
    "    \n",
    "You can find an example for using a custom dataset in the comments below.\n",
    "\n",
    "You also need to choose `patch_shape`, which determines the size of the patches used for training, here."
   ]
  },
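  {
   "cell_type": "markdown",
   "id": "3f9a1c52",
   "metadata": {},
   "source": [
    "If training and validation data come from the same stack, they can be separated with regions of interest, which are ordinary numpy slices (see `train_rois` and `val_rois` in the next cell). As a minimal sketch, using a random array as a stand-in for a real volume (the shape and the split position are assumptions chosen for illustration), two slices along the first axis yield non-overlapping training and validation regions:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# stand-in for a volume stored as (z, y, x); real data would be loaded from hdf5, zarr or n5\n",
    "volume = np.random.rand(100, 64, 64)\n",
    "\n",
    "# use the first 68 slices for training and the remaining slices for validation\n",
    "split = 68\n",
    "train_roi = np.s_[:split, :, :]\n",
    "val_roi = np.s_[split:, :, :]\n",
    "\n",
    "train_volume = volume[train_roi]\n",
    "val_volume = volume[val_roi]\n",
    "print(train_volume.shape, val_volume.shape)\n",
    "\n",
    "# check that the validation region is large enough to sample a patch from it\n",
    "example_patch_shape = (32, 64, 64)\n",
    "fits = all(v >= p for v, p in zip(val_volume.shape, example_patch_shape))\n",
    "print(\"validation region fits a patch:\", fits)\n",
    "```\n",
    "\n",
    "Both regions should be at least as large as `patch_shape` along every axis so that patches can be sampled from them."
   ]
  },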
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "968006e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CONFIGURE ME\n",
    "\n",
    "#\n",
    "# use a pre-configured dataset\n",
    "#\n",
    "\n",
    "# Specify a pre-configured dataset. Set to `None` in order to specify the training data via file-paths instead.\n",
    "preconfigured_dataset = None\n",
    "\n",
    "# example: use the pre-configured `snemi` dataset\n",
    "# preconfigured_dataset = \"snemi\"\n",
    "\n",
    "# Where to download the training data (the data will be downloaded only once).\n",
    "# If you work in google colab you may want to adapt this path to be on your google drive, in order\n",
    "# to not lose the data after each session.\n",
    "download_folder = f\"./training_data/{preconfigured_dataset}\"\n",
    "\n",
    "#\n",
    "# use a custom dataset\n",
    "#\n",
    "\n",
    "# Create a custom dataset from local data by specifying the paths for training data, training labels\n",
    "# as well as validation data and validation labels\n",
    "train_data_paths = []\n",
    "val_data_paths = []\n",
    "data_key = \"\"\n",
    "train_label_paths = []\n",
    "val_label_paths = []\n",
    "label_key = \"\"\n",
    "\n",
    "# In addition you can also specify regions of interest for training using the normal python slice syntax\n",
    "train_rois = None\n",
    "val_rois = None\n",
    "\n",
    "# example: Use training data and labels stored as a single stack in an hdf5 file.\n",
    "# This example is formulated using the data from the `snemi` dataset,\n",
    "# which stores the raw data in `volumes/raw` and the labels in `volumes/labels/neuron_ids`.\n",
    "# Note that we use roi's here to get separate training and val data from the same file.\n",
    "# train_data_paths etc. can also be lists in order to train from multiple stacks.\n",
    "\n",
    "# train_data_paths = train_label_paths = val_data_paths = val_label_paths = \"./training_data/snemi/snemi_train.h5\"\n",
    "# data_key = \"volumes/raw\"\n",
    "# label_key = \"volumes/labels/neuron_ids\"\n",
    "# train_rois = np.s_[:68, :, :]\n",
    "# val_rois = np.s_[68:, :, :]\n",
    "# patch_shape = (32, 256, 256)\n",
    "\n",
    "#\n",
    "# choose the patch shape\n",
    "#\n",
    "patch_shape = (32, 256, 256)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "436c227f",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_names = [\n",
    "    \"cremi\", \"mitoem\", \"platynereis-cells\", \"platynereis-nuclei\", \"snemi\"\n",
    "]\n",
    "\n",
    "def check_data(data_paths, label_paths, rois):\n",
    "    print(\"Loading the raw data from:\", data_paths, data_key)\n",
    "    print(\"Loading the labels from:\", label_paths, label_key)\n",
    "    try:\n",
    "        torch_em.default_segmentation_dataset(data_paths, data_key, label_paths, label_key, patch_shape, rois=rois)\n",
    "    except Exception as e:\n",
    "        print(\"Loading the dataset failed with:\")\n",
    "        raise e\n",
    "\n",
    "if preconfigured_dataset is None:\n",
    "    print(\"Using a custom dataset:\")\n",
    "    print(\"Checking the training dataset:\")\n",
    "    check_data(train_data_paths, train_label_paths, train_rois)\n",
    "    print(\"Checking the validation dataset:\")\n",
    "    check_data(val_data_paths, val_label_paths, val_rois)\n",
    "else:\n",
    "    assert preconfigured_dataset in dataset_names, f\"Invalid pre-configured dataset: {preconfigured_dataset}, choose one of {dataset_names}.\"\n",
    "\n",
    "assert len(patch_shape) == 3"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70f4b659",
   "metadata": {},
   "source": [
    "## Network output\n",
    "\n",
    "Choose the transformations applied to your label data in order to generate the network target data. If your labels can be fed to the network directly you don't need to do anything here. Otherwise, you can choose between the following transformations:\n",
    "- `foreground`: transforms labels into a binary target\n",
    "- `affinities`: transforms labels into affinity target\n",
    "- `boundaries`: transforms labels into boundary target\n",
    "\n",
    "Note that `affinities` and `boundaries` are mutually exclusive; `foreground` can be combined with either of the two other transformations. All three transformations are implemented to be applied to *instance labels*. See the screenshot below for an illustration of the expected input and the result of these transformations. You will see their result for your own data in `Check training data`.\n",
    "\n",
    "![targets](misc/targets.png)"
   ]
  },
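  {
   "cell_type": "markdown",
   "id": "7b2e4d18",
   "metadata": {},
   "source": [
    "To make the three transformations more concrete, here is a simplified numpy sketch on a tiny 2d label image. It is not the `torch_em` implementation: the neighbor comparison used for the boundaries and the convention for the single directed channel are assumptions made for illustration, and the actual transformations may differ in details such as boundary thickness, sign or masking.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# tiny 2d instance label image: 0 is background, 1 and 2 are two objects\n",
    "labels = np.array([\n",
    "    [0, 1, 1, 2],\n",
    "    [0, 1, 1, 2],\n",
    "    [0, 0, 2, 2],\n",
    "])\n",
    "\n",
    "# foreground: 1 wherever the label is not zero\n",
    "foreground_target = (labels != 0).astype(\"float32\")\n",
    "\n",
    "# label changes between horizontal and vertical neighbors\n",
    "diff_x = labels[:, :-1] != labels[:, 1:]\n",
    "diff_y = labels[:-1, :] != labels[1:, :]\n",
    "\n",
    "# boundary: 1 wherever a pixel differs from its right or lower neighbor\n",
    "boundary = np.zeros(labels.shape, dtype=bool)\n",
    "boundary[:, :-1] |= diff_x\n",
    "boundary[:-1, :] |= diff_y\n",
    "boundary_target = boundary.astype(\"float32\")\n",
    "\n",
    "# one directed boundary channel for the offset [0, -1]:\n",
    "# 1 wherever a pixel differs from its left neighbor\n",
    "left_channel = np.zeros(labels.shape, dtype=\"float32\")\n",
    "left_channel[:, 1:] = diff_x\n",
    "\n",
    "print(foreground_target)\n",
    "print(boundary_target)\n",
    "print(left_channel)\n",
    "```\n",
    "\n",
    "An affinity target stacks one such directed channel per entry in `offsets`, so the number of target channels grows with the number of offsets."
   ]
  },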
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6f1e790",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CONFIGURE ME\n",
    "\n",
    "# Whether to add a foreground channel (1 for all labels that are not zero) to the target.\n",
    "foreground = False\n",
    "# Whether to add affinity channels (= directed boundaries) or a boundary channel to the target.\n",
    "# Note that you can choose at most one of these two options.\n",
    "affinities = False\n",
    "boundaries = False\n",
    "\n",
    "# the pixel offsets that are used to compute the affinity channels\n",
    "offsets = [\n",
    "    [-1, 0, 0], [0, -1, 0], [0, 0, -1],\n",
    "    [-2, 0, 0], [0, -3, 0], [0, 0, -3],\n",
    "    [-3, 0, 0], [0, -9, 0], [0, 0, -9]\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa04e8e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "assert not (affinities and boundaries), \"Predicting both affinities and boundaries is not supported\"\n",
    "\n",
    "label_transform, label_transform2 = None, None\n",
    "if affinities:\n",
    "    label_transform2 = torch_em.transform.label.AffinityTransform(\n",
    "        offsets=offsets, add_binary_target=foreground, add_mask=True\n",
    "    )\n",
    "elif boundaries:\n",
    "    label_transform = torch_em.transform.label.BoundaryTransform(\n",
    "        add_binary_target=foreground\n",
    "    )\n",
    "elif foreground:\n",
    "    label_transform = torch_em.transform.label.labels_to_binary"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff3bcb05",
   "metadata": {},
   "source": [
    "## Loss, metric & batch size\n",
    "\n",
    "Choose important training parameters:\n",
    "\n",
    "- `loss`: the loss function; can be one of `\"bce\", \"ce\", \"dice\"` (binary cross entropy, cross entropy, dice) or a torch module\n",
    "- `metric`: the metric used for the validation data; same options as for `loss`\n",
    "- `batch_size`: the training batch size\n",
    "\n",
    "If you're unsure about these settings just use the default values, they are probably ok."
   ]
  },
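  {
   "cell_type": "markdown",
   "id": "c5d80a6e",
   "metadata": {},
   "source": [
    "As a rough illustration of what the dice loss measures, the sketch below implements one common formulation (with squared terms in the denominator) in numpy. The function `dice_loss` is a hypothetical helper written for this example, not the `torch_em.loss.DiceLoss` implementation. The formula expects predictions in the range [0, 1], which is why this notebook adds a sigmoid activation to the network when the dice loss is used.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def dice_loss(prediction, target, eps=1e-7):\n",
    "    # overlap between prediction and target\n",
    "    intersection = (prediction * target).sum()\n",
    "    # squared terms in the denominator; eps avoids a division by zero\n",
    "    denominator = (prediction ** 2).sum() + (target ** 2).sum()\n",
    "    return 1.0 - 2.0 * intersection / max(denominator, eps)\n",
    "\n",
    "\n",
    "target = np.array([0.0, 1.0, 1.0, 0.0])\n",
    "perfect = target.copy()\n",
    "inverted = 1.0 - target\n",
    "\n",
    "print(\"perfect prediction:\", dice_loss(perfect, target))\n",
    "print(\"inverted prediction:\", dice_loss(inverted, target))\n",
    "```\n",
    "\n",
    "In this formulation the loss is 0 when prediction and target agree completely and 1 when they do not overlap at all."
   ]
  },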
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76bef2e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CONFIGURE ME\n",
    "batch_size = 1\n",
    "loss = \"dice\"\n",
    "metric = \"dice\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33a4ba21",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_loss(loss_name):\n",
    "    loss_names = [\"bce\", \"ce\", \"dice\"]\n",
    "    if isinstance(loss_name, str):\n",
    "        assert loss_name in loss_names, f\"{loss_name}, {loss_names}\"\n",
    "        if loss_name == \"dice\":\n",
    "            loss_function = torch_em.loss.DiceLoss()\n",
    "        elif loss_name == \"ce\":\n",
    "            loss_function = nn.CrossEntropyLoss()\n",
    "        elif loss_name == \"bce\":\n",
    "            loss_function = nn.BCEWithLogitsLoss()\n",
    "    else:\n",
    "        loss_function = loss_name\n",
    "    \n",
    "    # we need to add a loss wrapper for affinities\n",
    "    if affinities:\n",
    "        loss_function = torch_em.loss.LossWrapper(\n",
    "            loss_function, transform=torch_em.loss.ApplyAndRemoveMask()\n",
    "        )\n",
    "    return loss_function\n",
    "\n",
    "\n",
    "loss_function = get_loss(loss)\n",
    "metric_function = get_loss(metric)\n",
    "\n",
    "kwargs = dict(\n",
    "    ndim=3, patch_shape=patch_shape, batch_size=batch_size,\n",
    "    label_transform=label_transform, label_transform2=label_transform2\n",
    ")\n",
    "ds = preconfigured_dataset\n",
    "\n",
    "# TODO add support for plantseg loader\n",
    "if ds is None:\n",
    "    train_loader = torch_em.default_segmentation_loader(\n",
    "        train_data_paths, data_key, train_label_paths, label_key,\n",
    "        rois=train_rois, **kwargs\n",
    "    )\n",
    "    val_loader = torch_em.default_segmentation_loader(\n",
    "        val_data_paths, data_key, val_label_paths, label_key,\n",
    "        rois=val_rois, **kwargs\n",
    "    )\n",
    "else:\n",
    "    kwargs.update(dict(download=True))\n",
    "    if ds == \"cremi\":\n",
    "        assert not foreground, \"Foreground prediction for neuron segmentation does not make sense, please change these settings\"\n",
    "        train_samples = (\"A\", \"B\", \"C\")\n",
    "        val_samples = (\"C\",)\n",
    "        train_rois = (np.s_[:, :, :], np.s_[:, :, :], np.s_[:75, :, :])\n",
    "        val_rois = (np.s_[75:, :, :],)\n",
    "        train_loader = torchem_data.get_cremi_loader(download_folder, samples=train_samples, rois=train_rois, **kwargs)\n",
    "        val_loader = torchem_data.get_cremi_loader(download_folder, samples=val_samples, rois=val_rois, **kwargs)\n",
    "    elif ds == \"mitoem\":\n",
    "        train_loader = torchem_data.get_mitoem_loader(download_folder, splits=\"train\", **kwargs)\n",
    "        val_loader = torchem_data.get_mitoem_loader(download_folder, splits=\"val\", **kwargs)\n",
    "    elif ds == \"platynereis-cells\":\n",
    "        train_samples = list(range(1, 10))\n",
    "        val_samples = [9]\n",
    "        train_rois = {9: np.s_[:, :600, :]}\n",
    "        val_rois = {9: np.s_[:, 600:, :]}\n",
    "        train_loader = torchem_data.get_platynereis_cell_loader(\n",
    "            download_folder, sample_ids=train_samples, rois=train_rois, **kwargs\n",
    "        )\n",
    "        val_loader = torchem_data.get_platynereis_cell_loader(\n",
    "            download_folder, sample_ids=val_samples, rois=val_rois, **kwargs\n",
    "        )\n",
    "    elif ds == \"platynereis-nuclei\":\n",
    "        train_samples = [1, 3, 6, 7, 8, 9, 10, 11, 12]\n",
    "        val_samples = [2, 4]\n",
    "        train_loader = torchem_data.get_platynereis_nuclei_loader(\n",
    "            download_folder, sample_ids=train_samples, **kwargs\n",
    "        )\n",
    "        val_loader = torchem_data.get_platynereis_nuclei_loader(\n",
    "            download_folder, sample_ids=val_samples, **kwargs\n",
    "        )\n",
    "    elif ds == \"snemi\":\n",
    "        assert not foreground, \"Foreground prediction for neuron segmentation does not make sense, please change these settings\"\n",
    "        n_slices = 100\n",
    "        z = n_slices - patch_shape[0]\n",
    "        train_roi, val_roi = np.s_[:z, :, :], np.s_[z:, :, :]\n",
    "        train_loader = torchem_data.get_snemi_loader(download_folder, sample=\"train\", rois=train_roi, **kwargs)\n",
    "        val_loader = torchem_data.get_snemi_loader(download_folder, sample=\"train\", rois=val_roi, **kwargs)\n",
    "\n",
    "assert train_loader is not None, \"Something went wrong\"\n",
    "assert val_loader is not None, \"Something went wrong\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ef2cc6f",
   "metadata": {},
   "source": [
    "## Check training data\n",
    "\n",
    "Check the output from your data loader. It consists of the input data for your network and the target data.\n",
    "You should check that the target looks reasonable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "445bb0c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CONFIGURE ME\n",
    "\n",
    "# Choose the number of samples to check per loader.\n",
    "n_samples = 2\n",
    "\n",
    "# Whether to use napari or matplotlib to view the training data.\n",
    "# Napari can display 3d volumes, but it might not be installed by default and is not available on google colab.\n",
    "use_napari = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d5d5f12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# only try to import napari if it was requested\n",
    "if use_napari:\n",
    "    try:\n",
    "        import napari\n",
    "    except Exception:\n",
    "        print(\"use_napari was set to True, but napari is not available, using matplotlib as fallback solution\")\n",
    "        use_napari = False\n",
    "\n",
    "print(\"Training samples\")\n",
    "check_loader(train_loader, n_samples, plt=not use_napari)\n",
    "print(\"Validation samples\")\n",
    "check_loader(val_loader, n_samples, plt=not use_napari)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd5f1b55",
   "metadata": {},
   "source": [
    "## Network architecture\n",
    "\n",
    "Choose the important network architecture parameters for the 3D UNet:\n",
    "- `scale_factors`: the down/upscaling factors between each encoder/decoder level. This is specified as a list of 3d scale factors, which enables anisotropic scaling. See the example in the comments for details.\n",
    "- `initial_features`: the number of features in the first encoder level, the number will be doubled for each level\n",
    "- `final_activation`: the activation applied to the UNet output\n",
    "- `in_channels`: the number of input channels (= number of channels of the raw data)\n",
    "- `out_channels`: the number of output channels (usually the number of target channels)\n",
    "\n",
    "If you're unsure about these settings just use the default values, they are probably ok."
   ]
  },
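  {
   "cell_type": "markdown",
   "id": "9e14f7b3",
   "metadata": {},
   "source": [
    "To get a feeling for how `scale_factors` and `initial_features` interact with the patch shape, the following sketch tracks the spatial shape and the number of features per level. The helper `encoder_shapes` is hypothetical and only does simple bookkeeping; it assumes that the features double at every level, including the lowest one, and it does not reproduce the internals of `AnisotropicUNet`.\n",
    "\n",
    "```python\n",
    "def encoder_shapes(patch_shape, scale_factors, initial_features):\n",
    "    # return (spatial shape, number of features) per level, starting at full resolution\n",
    "    shape, features = tuple(patch_shape), initial_features\n",
    "    levels = [(shape, features)]\n",
    "    for factors in scale_factors:\n",
    "        # the shape should be divisible by the scale factor so that\n",
    "        # upscaling in the decoder restores the same shape\n",
    "        assert all(s % f == 0 for s, f in zip(shape, factors)), (shape, factors)\n",
    "        shape = tuple(s // f for s, f in zip(shape, factors))\n",
    "        features *= 2\n",
    "        levels.append((shape, features))\n",
    "    return levels\n",
    "\n",
    "\n",
    "levels = encoder_shapes(\n",
    "    patch_shape=(32, 256, 256),\n",
    "    scale_factors=[[1, 2, 2], [1, 2, 2], [2, 2, 2], [2, 2, 2]],\n",
    "    initial_features=32,\n",
    ")\n",
    "for shape, features in levels:\n",
    "    print(shape, features)\n",
    "```\n",
    "\n",
    "With the anisotropic factors in the first two levels, the first axis keeps its full size of 32 while the other two axes are already downscaled; this is useful for data with a lower resolution along that axis."
   ]
  },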
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21b899e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CONFIGURE ME\n",
    "\n",
    "# example for isotropic scaling with a depth of 4\n",
    "# scale_factors = 4 * [[2, 2, 2]]\n",
    "\n",
    "# example for 4 levels with anisotropic scaling in the first two levels (scale only in xy)\n",
    "scale_factors = [[1, 2, 2], [1, 2, 2], [2, 2, 2], [2, 2, 2]]\n",
    "\n",
    "initial_features = 32\n",
    "final_activation = None\n",
    "\n",
    "# If you leave the in/out_channels as None an attempt will be made to automatically deduce these numbers. \n",
    "in_channels = None\n",
    "out_channels = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b3e93cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "if final_activation is None and loss == \"dice\":\n",
    "    final_activation = \"Sigmoid\"\n",
    "    print(\"Adding a sigmoid activation because we are using dice loss\")\n",
    "\n",
    "if in_channels is None:\n",
    "    in_channels = 1\n",
    "\n",
    "if out_channels is None:\n",
    "    if affinities:\n",
    "        n_off = len(offsets)\n",
    "        out_channels = n_off + 1 if foreground else n_off\n",
    "    elif boundaries:\n",
    "        out_channels = 2 if foreground else 1\n",
    "    elif foreground:\n",
    "        out_channels = 1\n",
    "    assert out_channels is not None, \"The number of out channels could not be deduced automatically. Please set it manually in the cell above.\"\n",
    "\n",
    "print(\"Creating 3d UNet with\", in_channels, \"input channels and\", out_channels, \"output channels.\")\n",
    "model = AnisotropicUNet(\n",
    "    in_channels=in_channels, out_channels=out_channels, scale_factors=scale_factors,\n",
    "    initial_features=initial_features, final_activation=final_activation\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "85d6f2e9",
   "metadata": {},
   "source": [
    "## Tensorboard\n",
    "\n",
    "Start the tensorboard in order to keep track of the training progress."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b94a0d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "%tensorboard --logdir logs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95456878",
   "metadata": {},
   "source": [
    "## Training\n",
    "\n",
    "Choose additional training parameters:\n",
    "- `experiment_name`: the name for this training run, which will be used for naming the checkpoint for this model and for identifying the model in tensorboard.\n",
    "- `n_iterations`: number of iterations to train for.\n",
    "- `learning_rate`: the learning rate for gradient based updates.\n",
    "\n",
    "This also starts the training!\n",
    "\n",
    "**Important:** If you're on google colab the checkpoint will not be saved permanently. To save it you will need to copy the local folder `checkpoints` to your google drive."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4168785",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CONFIGURE ME\n",
    "experiment_name = \"my-shiny-net\"\n",
    "n_iterations = 10000\n",
    "learning_rate = 1.0e-4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "738cd81a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# IMPORTANT! if your session on google colab crashes here, you will need to uncomment the 'logger=None' comment\n",
    "# in this case you can't use tensorboard, but everything else will work as expected\n",
    "# (this happens due to incompatible google protobuf versions and I don't have time to fix this right now)\n",
    "trainer = torch_em.default_segmentation_trainer(\n",
    "    name=experiment_name, model=model,\n",
    "    train_loader=train_loader, val_loader=val_loader,\n",
    "    loss=loss_function, metric=metric_function,\n",
    "    learning_rate=learning_rate,\n",
    "    mixed_precision=True,\n",
    "    log_image_interval=50,\n",
    "    # logger=None\n",
    ")\n",
    "trainer.fit(n_iterations)"
   ]
  },
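  {
   "cell_type": "markdown",
   "id": "d2a6b049",
   "metadata": {},
   "source": [
    "If you are on google colab, the `checkpoints` folder can be copied to the mounted drive after training, for example with the standard library. The helper `backup_folder` and the target path in the comment are hypothetical; adapt the path to your own drive layout.\n",
    "\n",
    "```python\n",
    "import shutil\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def backup_folder(source, destination):\n",
    "    # copy the folder `source` to `destination`, overwriting files from earlier backups\n",
    "    source, destination = Path(source), Path(destination)\n",
    "    if not source.is_dir():\n",
    "        raise FileNotFoundError(f\"No folder found at {source}\")\n",
    "    # dirs_exist_ok requires python 3.8 or newer\n",
    "    shutil.copytree(source, destination, dirs_exist_ok=True)\n",
    "    return destination\n",
    "\n",
    "\n",
    "# example call with a hypothetical target path on the mounted drive:\n",
    "# backup_folder(\"checkpoints\", \"/content/gdrive/MyDrive/torch_em/checkpoints\")\n",
    "```\n",
    "\n",
    "Running the copy again after further training overwrites the files of the earlier backup."
   ]
  },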
  {
   "cell_type": "markdown",
   "id": "21c0b67f",
   "metadata": {},
   "source": [
    "## Check trained network\n",
    "\n",
    "Look at predictions from the trained network and their comparison to the target."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e949a598",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CONFIGURE ME\n",
    "\n",
    "# Choose the number of samples to check.\n",
    "n_samples = 2\n",
    "\n",
    "# Whether to use napari or matplotlib to view the training data.\n",
    "# Napari can display 3d volumes, but it might not be installed by default and is not available on google colab.\n",
    "use_napari = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd229f14",
   "metadata": {},
   "outputs": [],
   "source": [
    "# only try to import napari if it was requested\n",
    "if use_napari:\n",
    "    try:\n",
    "        import napari\n",
    "    except Exception:\n",
    "        print(\"use_napari was set to True, but napari is not available, using matplotlib as fallback solution\")\n",
    "        use_napari = False\n",
    "\n",
    "check_trainer(trainer, n_samples, plt=not use_napari)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae96b202",
   "metadata": {},
   "source": [
    "## Export network to bioimage.io format\n",
    "\n",
    "Finally, you can export the trained model in the format compatible with [BioImage.IO](https://bioimage.io/#/), a modelzoo for bioimage analysis. After exporting, you can upload the model there to share it with other researchers.\n",
    "You only need to configure where to save the model via `export_folder` and whether to convert it to additional formats via `additional_weight_formats`. You should also write some documentation for the model in `doc`. When you run the second cell, you will be prompted to enter additional information."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "183b5afc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CONFIGURE ME\n",
    "\n",
    "# The folder where the bioimageio model will be saved (as a .zip file).\n",
    "# If you run in google colab you should adapt this path to your google drive so that you can download the saved model.\n",
    "export_folder = \"./my-fancy-bio-model\"\n",
    "\n",
    "# Whether to convert the model weights to additional formats.\n",
    "# Currently, torchscript and onnx are supported; converting to them enables running the model\n",
    "# in more software tools.\n",
    "additional_weight_formats = None\n",
    "# additional_weight_formats = [\"torchscript\"]\n",
    "\n",
    "doc = None\n",
    "# write some markdown documentation like this, otherwise a default documentation text will be used\n",
    "# doc = \"\"\"#My Fancy Model\n",
    "# This is a fancy model to segment shiny objects in images.\n",
    "# \"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c88f6f42",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch_em.util.modelzoo\n",
    "\n",
    "for_dij = additional_weight_formats is not None and \"torchscript\" in additional_weight_formats\n",
    "\n",
    "training_data = None\n",
    "if preconfigured_dataset is not None:\n",
    "    if preconfigured_dataset.startswith(\"platynereis\"):\n",
    "        data_id = torchem_data.get_bioimageio_dataset_id(\"platynereis\")\n",
    "    else:\n",
    "        data_id = torchem_data.get_bioimageio_dataset_id(preconfigured_dataset)\n",
    "    if data_id:\n",
    "        training_data = {\"id\": data_id}\n",
    "\n",
    "pred_str = \"\"\n",
    "if affinities:\n",
    "    pred_str = \"affinities and foreground probabilities\" if foreground else \"affinities\"\n",
    "elif boundaries:\n",
    "    pred_str = \"boundary and foreground probabilities\" if foreground else \"boundaries\"\n",
    "elif foreground:\n",
    "    pred_str = \"foreground\"\n",
    "\n",
    "default_doc = f\"\"\"#{experiment_name}\n",
    "\n",
    "This model was trained with [the torch_em 3d UNet notebook](https://github.com/constantinpape/torch-em/blob/main/experiments/3D-UNet-Training.ipynb).\n",
    "\"\"\"\n",
    "if pred_str:\n",
    "    default_doc += f\"It predicts {pred_str}.\\n\"\n",
    "\n",
    "training_summary = torch_em.util.get_training_summary(trainer, to_md=True, lr=learning_rate)\n",
    "default_doc += f\"\"\"## Training Schedule\n",
    "\n",
    "{training_summary}\n",
    "\"\"\"\n",
    "\n",
    "if doc is None:\n",
    "    doc = default_doc\n",
    "\n",
    "torch_em.util.modelzoo.export_bioimageio_model(\n",
    "    trainer, export_folder, input_optional_parameters=True,\n",
    "    for_deepimagej=for_dij, training_data=training_data, documentation=doc\n",
    ")\n",
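    "# only convert the weights if additional formats were requested\n",
    "if additional_weight_formats:\n",
    "    torch_em.util.modelzoo.add_weight_formats(export_folder, additional_weight_formats)"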
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
